import logging
from typing import Literal

from .._ast import ToolCallNode
from .._grammar import capture, quote_regex
from .._grammar import gen as grammar_gen
from .._grammar import regex as regex_node
from .._guidance import guidance

# Module-level logger named after this module; `gen` uses it for debug tracing.
logger = logging.getLogger(__name__)


def gen(
    name=None,
    *,
    max_tokens=None,
    list_append=False,
    regex=None,
    tools=None,
    hide_tool_call=False,
    stop: str | list[str] | None = None,
    stop_regex: str | list[str] | None = None,
    suffix: str | None = None,
    n=1,
    temperature=None,
    top_p=1.0,
    save_stop_text: bool | str = False,
    tool_choice: Literal["auto", "required"] = "auto",
):
    """Generate a set of tokens until a given stop criteria has been met.

    This function is a useful utility that can allow you to specify most grammars used by typical
    LM generation programs. It also has the added ability to interleave generation with tool calls.

        >>> lm += gen("my_generation", max_tokens=10)
        >>> print(lm["my_generation"])
        some text from the LLM

    Parameters
    ----------

        name : str or None
            If this is not None then the results of the generation will be saved as a variable on
            the Model object (so you can access the result as `lm["var_name"]`).
            Not supported together with `tools` yet.

        max_tokens : int or None
            The maximum number of generation tokens we should use. Note that this limit is not exact when
            regular expression pattern constraints are present, but guidance does attempt to end the generation
            as soon as possible while keeping the regex constraints satisfied.
            Not supported together with `tools` yet.

        list_append : bool
            If this is True then the results saved to `lm[name]` will not be written directly but rather appended
            to a list (if no list with the current name is present one will be created). This is useful for
            building lists inside python loops.

        regex : str or None
            This is a regular expression that will be used to constrain the generation. The model is only allowed
            to generate tokens that match this regular expression. Note that for variable length expressions the
            model is free to continue the expression after a complete match, but generation will terminate as soon
            as the model generates anything that does not match the pattern (this ending behavior may change a bit
            as we update guidance to maintain the grammar parsing state between calls).
            When `tools` are given, this pattern is forwarded as the constraint on the plain-text
            (non tool call) part of the output.

        stop : str or list or None
            The stop string (or list of strings) we should use for terminating this generation segment.
            Strings are matched literally. An empty list means "no stop strings".
            Cannot be combined with `stop_regex` or with `tools`.

        stop_regex : str or list or None
            The stop regular expression (or list of regular expressions) we should use for terminating this
            generation segment. A list is combined into a single alternation. An empty list means
            "no stop regexes". Cannot be combined with `stop` or with `tools`.

        suffix : str or None
            Forwarded unchanged to the underlying grammar-level `gen` (presumably text emitted after
            the generation stops -- confirm against `_grammar.gen`). Cannot be used with `tools`.

        save_stop_text : bool or str
            If True then this saves the captured stop text or regex into a variable of the name
            `str(name) + "_stop_text"` (note: with `name=None` this is "None_stop_text"). If a string
            is given then the captured stop text is saved under that name.

        temperature : float or None
            The temperature to use during this generation call. Note that when parsing ambiguous grammars that include
            multiple conflicting temperatures (for example from multiple possible `gen` calls inside a `select`) the highest
            temperature of all options is used by the model (since we only want to run the model once, not once for every
            possible parse path). Not supported together with `tools` yet.

        top_p : float
            Only the default of 1.0 is accepted here; use `model.with_sampling_params` to set top_p.

        n : int
            Only n=1 is supported; call `gen` in a loop to get several generations.

        tools : Tool or list or None
            A list of `guidance.Tool` or Python functions (which will be converted to `guidance.Tool`).
            When `tools` are given, `stop`, `stop_regex`, `suffix`, `temperature`, `max_tokens` and `name`
            cannot be used. Parallel tool calls are not supported yet.

        tool_choice : "auto" or "required"
            Forwarded to the tool-call node when `tools` are given; ignored otherwise.

        hide_tool_call : bool
            Deprecated. Passing a truthy value raises `ValueError`.

    Raises
    ------

        ValueError
            If `hide_tool_call` is set, if both `stop` and `stop_regex` are given, or if `stop`,
            `stop_regex` or `suffix` is combined with `tools`.

        NotImplementedError
            If `temperature`, `max_tokens` or `name` is combined with `tools`.

        AssertionError
            If `n != 1` or `top_p != 1` (only when assertions are enabled).
    """
    if hide_tool_call:
        raise ValueError("`hide_tool_call` is deprecated")
    if tools:
        if stop:
            raise ValueError("Cannot use `stop` with `tools`")
        if stop_regex:
            raise ValueError("Cannot use `stop_regex` with `tools`")
        if suffix:
            raise ValueError("Cannot use `suffix` with `tools`")
        if temperature is not None:
            raise NotImplementedError("`temperature` is not supported with `tools` yet")
        if max_tokens is not None:
            raise NotImplementedError("`max_tokens` is not supported with `tools` yet")
        if name is not None:
            raise NotImplementedError("`name` is not supported with `tools` yet")

        @guidance(stateless=False, dedent=False)
        def tool_gen(lm):
            return lm + ToolCallNode.from_tools(
                tools=tools,
                tool_choice=tool_choice,
                parallel_tool_calls=False,  # TODO: support parallel tool calls
                plaintext_regex=regex,
            )

        return tool_gen()

    # NOTE(review): these are `assert`s, so they are skipped under `python -O`.
    # They are kept as asserts because the AssertionError type is the current contract.
    assert n == 1, "We still need to add support for n>1! Consider putting your gen call in a loop for now."
    assert top_p == 1, "Please use `model.with_sampling_params` to set top_p."

    # Lazy %-formatting: the message is only built if DEBUG logging is enabled.
    logger.debug('start gen(name="%s")', name)

    # Checked before normalization so that e.g. `stop=[]` with a `stop_regex` still errors.
    if stop is not None and stop_regex is not None:
        raise ValueError("Cannot use both stop and stop_regex")
    if isinstance(stop, list):
        # Literal stop strings are escaped and folded into a stop-regex list.
        stop_regex = [quote_regex(s) for s in stop]
        stop = None
    if isinstance(stop_regex, list):
        # An empty list means "no stop patterns"; joining it would instead yield the
        # empty regex "" and forward that as a stop pattern.
        stop_regex = "|".join(stop_regex) if stop_regex else None

    if save_stop_text is False:
        save_stop_name = None
    elif save_stop_text is True:
        # TODO: "None_stop_text" -- is that really what we want?
        save_stop_name = str(name) + "_stop_text"
    else:
        save_stop_name = save_stop_text

    return grammar_gen(
        regex=regex,
        stop_regex=stop_regex,
        stop=stop,
        suffix=suffix,
        stop_capture=save_stop_name,
        name=name,
        list_append=list_append,
        temperature=temperature,
        max_tokens=max_tokens,
    )


@guidance(stateless=True)
def regex(lm, pattern, *, name=None):
    """Append a grammar node matching the regular expression `pattern` to `lm`.

    Parameters
    ----------

        pattern : str
            The regular expression to match; passed unchanged to `regex_node`.

        name : str or None
            If truthy, the node is wrapped with `capture(node, name)` so the matched text is
            recorded under that name (presumably readable as `lm[name]`, as with `gen`).
            An empty string or None means no capture.
    """
    pattern_node = regex_node(pattern)
    if not name:
        return lm + pattern_node
    return lm + capture(pattern_node, name)
